# yapf: disable
import pytest

from garage.envs import GymEnv, normalize
from garage.np.baselines import LinearFeatureBaseline
from garage.sampler import LocalSampler
from garage.tf.algos import TRPO
from garage.tf.optimizers import (ConjugateGradientOptimizer,
                                  FiniteDifferenceHVP)
from garage.tf.policies import (GaussianGRUPolicy,
                                GaussianLSTMPolicy,
                                GaussianMLPPolicy)
from garage.trainer import TFTrainer

from tests.fixtures import snapshot_config, TfGraphTestCase

# yapf: enable

# Gaussian policy classes fed to the parametrized TRPO smoke test below.
policies = [
    GaussianGRUPolicy,
    GaussianLSTMPolicy,
    GaussianMLPPolicy,
]


class TestGaussianPolicies(TfGraphTestCase):
    """Smoke tests running TRPO with each Gaussian policy class."""

    @pytest.mark.parametrize('policy_cls', policies)
    def test_gaussian_policies(self, policy_cls):
        """Train TRPO for one epoch on Pendulum with the given policy class.

        The test passes if setup and a single training epoch complete
        without raising; no return or performance threshold is asserted.

        Args:
            policy_cls (type): Policy class, constructed with ``name`` and
                ``env_spec`` keyword arguments.

        """
        with TFTrainer(snapshot_config, sess=self.sess) as trainer:
            env = normalize(GymEnv('Pendulum-v0'))
            # Close the environment even when policy construction, setup or
            # training raises, so a failing case does not leave it open.
            try:
                policy = policy_cls(name='policy', env_spec=env.spec)

                baseline = LinearFeatureBaseline(env_spec=env.spec)

                algo = TRPO(
                    env_spec=env.spec,
                    policy=policy,
                    baseline=baseline,
                    discount=0.99,
                    max_kl_step=0.01,
                    optimizer=ConjugateGradientOptimizer,
                    optimizer_args=dict(hvp_approach=FiniteDifferenceHVP(
                        base_eps=1e-5)),
                )

                trainer.setup(algo, env, sampler_cls=LocalSampler)
                trainer.train(n_epochs=1, batch_size=4000)
            finally:
                env.close()
